{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "4oQJ4NSK5tMq"
   },
   "source": [
    "# Quick start\n",
    "\n",
    "MindNLP offers powerful functionalities for training and using AI models for various tasks. To get started, this tutorial will guide you through loading a pretrained model and fine-tuning it to fit your specific needs.\n",
    "\n",
    "Using a pretrained model has great benefits: it saves computing time and resources. Fine-tuning allows you to adapt these models for enhanced performance on your unique dataset. Now that you're ready, let's get started!\n",
    "\n",
    "We will use the [BERT](https://arxiv.org/abs/1810.04805) model as an example and fine-tune it to perform a classification task on the [Large Movie Review Dataset](https://huggingface.co/datasets/stanfordnlp/imdb).\n",
    "\n",
    "MindNLP provides two approaches to fine-tuning. The first is the user-friendly Trainer API, which supports the essential training functionalities. The second is native MindSpore, which gives you more customized control. We will guide you through both approaches in this tutorial.\n",
    "\n",
    "For both approaches, you first need to prepare the dataset by running the [Prepare a dataset](#prepare_a_dataset) part of this tutorial.\n",
    "\n",
    "After the dataset is ready, choose one of the strategies below and start your journey!\n",
    "* [Fine-tune a pretrained model with MindNLP Trainer.](#train_with_mindnlp_trainer)\n",
    "* [Fine-tune a pretrained model in native MindSpore.](#train_with_native_mindspore)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Tsk4afqY5tMr"
   },
   "source": [
    "<a id='prepare_a_dataset'></a>\n",
    "## Prepare a dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FHiatYkn5tMt"
   },
   "source": [
    "Before you can fine-tune a pretrained model, download a dataset and prepare it for training.\n",
    "\n",
    "MindNLP includes a `load_dataset` API that loads any dataset from the Hugging Face dataset repository. Here, let's use it to load the [Large Movie Review Dataset](https://huggingface.co/datasets/stanfordnlp/imdb), which is named `'imdb'`, and split it into training, validation and test datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kD9Xs3Ke5tMx",
    "outputId": "9e90da89-f210-48a4-cde4-3063ce5c5a2d",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from mindnlp import load_dataset\n",
    "\n",
    "imdb_ds = load_dataset('imdb', split=['train', 'test'])\n",
    "imdb_train = imdb_ds['train']\n",
    "imdb_test = imdb_ds['test']\n",
    "\n",
    "# Split train dataset further into training and validation datasets\n",
    "imdb_train, imdb_val = imdb_train.split([0.7, 0.3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rFimts_65tMy"
   },
   "source": [
    "Next, load the tokenizer for the model. The process of tokenization converts raw text into a format that machine learning models can process, which is crucial for natural language processing tasks.\n",
    "\n",
    "In MindNLP, `AutoTokenizer` helps automatically fetch and instantiate the appropriate tokenizer for a pre-trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "fkSS_Dkv5tMy"
   },
   "outputs": [],
   "source": [
    "from mindnlp.transformers import AutoTokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the dataset and the tokenizer are ready, we can process the dataset.\n",
    "\n",
    "Processing includes the following steps:\n",
    "* Tokenize the text.\n",
    "* Cast the labels to the correct datatype.\n",
    "* Handle variable sequence lengths with padding or truncation.\n",
    "* Shuffle the order of entries.\n",
    "* Batch the dataset.\n",
    "\n",
    "These steps are explained in detail in the [Data_Preprocess](./1.data_preprocess.ipynb) tutorial.\n",
    "\n",
    "Here, we define the following `process_dataset` function to prepare the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore\n",
    "import numpy as np\n",
    "from mindspore.dataset import transforms\n",
    "\n",
    "def process_dataset(dataset, tokenizer, max_seq_len=256, batch_size=32, shuffle=False, take_len=None):\n",
    "    # The tokenize function\n",
    "    def tokenize(text):\n",
    "        tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)\n",
    "        return tokenized['input_ids'], tokenized['token_type_ids'], tokenized['attention_mask']\n",
    "\n",
    "    # Shuffle the order of the dataset\n",
    "    if shuffle:\n",
    "        dataset = dataset.shuffle(buffer_size=batch_size)\n",
    "\n",
    "    # Select the first several entries of the dataset\n",
    "    if take_len:\n",
    "        dataset = dataset.take(take_len)\n",
    "\n",
    "    # Apply the tokenize function, transforming the 'text' column into the three output columns generated by the tokenizer.\n",
    "    dataset = dataset.map(operations=[tokenize], input_columns=\"text\", output_columns=['input_ids', 'token_type_ids', 'attention_mask'])\n",
    "    # Cast the datatype of the 'label' column to int32 and rename the column to 'labels'\n",
    "    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns=\"label\", output_columns=\"labels\")\n",
    "    # Batch the dataset with padding.\n",
    "    dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),\n",
    "                                                         'token_type_ids': (None, 0),\n",
    "                                                         'attention_mask': (None, 0)})\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now process all splits of the dataset, taking a smaller subset of each to shorten fine-tuning:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 4\n",
    "take_len = batch_size * 200\n",
    "small_dataset_train = process_dataset(imdb_train, tokenizer, batch_size=batch_size, shuffle=True, take_len=take_len)\n",
    "small_dataset_val = process_dataset(imdb_val, tokenizer, batch_size=batch_size, shuffle=True, take_len=take_len)\n",
    "small_dataset_test = process_dataset(imdb_test, tokenizer, batch_size=batch_size, shuffle=True, take_len=take_len)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, `take_len` is an optional parameter that keeps only `take_len` entries of each split, which shortens fine-tuning.\n",
    "\n",
    "In practical fine-tuning jobs, however, the full dataset is normally used."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9-wrVSTe5tM2"
   },
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JrL2P5TQ5tM2"
   },
   "source": [
    "At this stage, you can choose either the [MindNLP Trainer API](#train_with_mindnlp_trainer) or the [native MindSpore](#train_with_native_mindspore) approach to fine-tune the model.\n",
    "\n",
    "Let's start with the Trainer API approach."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gkx8lgzD5tM3"
   },
   "source": [
    "<a id='train_with_mindnlp_trainer'></a>\n",
    "### Train with MindNLP Trainer\n",
    "\n",
    "MindNLP comes with a [`Trainer`](https://github.com/mindspore-lab/mindnlp/tree/master/mindnlp/engine/trainer) class designed to simplify model training. With `Trainer`, you do not need to write your own training loop.\n",
    "\n",
    "`Trainer` supports a wide range of training options, which are explained in the [Use Trainer](./2.use_trainer.ipynb) tutorial."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gkx8lgzD5tM3"
   },
   "source": [
    "#### Initialize the model\n",
    "In our example here, we will first instantiate the pretrained BERT model.\n",
    "\n",
    "For this purpose, we use `AutoModelForSequenceClassification`. Supply the name of the pretrained model, here `'bert-base-cased'`, to `AutoModelForSequenceClassification.from_pretrained`. It will automatically infer the model architecture, instantiate the model and load the pretrained parameters. The model loaded here is a BERT model specialized in classification tasks, `BertForSequenceClassification`.\n",
    "\n",
    "To supply additional arguments to the model initialization, you can add more keyword arguments. Here, since the classification task involves determining whether a movie review expresses a positive or negative sentiment, we supply `num_labels=2` to the BERT model.\n",
    "\n",
    "For different types of tasks, MindNLP has a variety of `AutoModel` classes to choose from."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "SM1APoY_5tM3"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following parameters in checkpoint files are not loaded:\n",
      "['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
      "The following parameters in models are missing parameter:\n",
      "['classifier.weight', 'classifier.bias']\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'mindnlp.transformers.models.bert.modeling_bert.BertForSequenceClassification'>\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.transformers import AutoModelForSequenceClassification\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased', num_labels=2)\n",
    "print(type(model))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mb761eVU5tM4"
   },
   "source": [
    "#### Training hyperparameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, create a `TrainingArguments` instance, where you define the hyperparameters used in training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindnlp.engine import TrainingArguments\n",
    "training_args = TrainingArguments(\n",
    "    \"../../output\",\n",
    "    per_device_train_batch_size=1,\n",
    "    per_device_eval_batch_size=1,\n",
    "    learning_rate=2e-5,\n",
    "    num_train_epochs=3,\n",
    "    logging_steps=200,\n",
    "    evaluation_strategy='epoch',\n",
    "    save_strategy='epoch'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For details on the other parameters of `TrainingArguments`, please refer to the [Use Trainer](./3.use_trainer.ipynb) tutorial. Here, we specified the following parameters.\n",
    "* `output_dir`: This is the directory where all outputs like model checkpoints and predictions will be saved. In this example, it is set to \"../../output\".\n",
    "* `per_device_train_batch_size`: This controls the batch size used for training on each device. Since we already batched our dataset, here the batch size is set to 1. If you want to use `Trainer`'s batch functionality, you can set your own batch size here.\n",
    "* `per_device_eval_batch_size`: Similar to the training batch size, but used during the evaluation phase on the validation data. Since we already batched our dataset, here the batch size is set to 1.\n",
    "* `learning_rate`: The step size used when updating the model parameters. Smaller values mean slower learning, but they may lead to better model fine-tuning.\n",
    "* `num_train_epochs`: Defines how many times the training loop will run over the entire training dataset.\n",
    "* `evaluation_strategy`: Determines the strategy for performing evaluation. Setting it to 'epoch' means that the model is evaluated at the end of each training epoch.\n",
    "* `logging_steps`: This setting controls how often the training loss and other metrics are logged to the console. It helps in monitoring the training progress.\n",
    "* `save_strategy`: Determines the strategy for saving model checkpoints. Setting it to 'epoch' ensures that the model is saved at the end of every epoch."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluate\n",
    "Evaluation is essential for understanding the model's performance and generalizability on new, unseen data.\n",
    "\n",
    "To enable evaluation of your model's performance during training, it's necessary to supply a function for metric computation to `Trainer`.\n",
    "\n",
    "Here, we write a `compute_metrics` function, which takes an `EvalPrediction` object as input and computes the evaluation metrics from the predictions and the ground-truth labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "import numpy as np\n",
    "from mindnlp.engine.utils import EvalPrediction\n",
    "\n",
    "metric = evaluate.load(\"accuracy\")\n",
    "\n",
    "def compute_metrics(eval_pred: EvalPrediction):\n",
    "    logits, labels = eval_pred\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return metric.compute(predictions=predictions, references=labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Initialize the trainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the `TrainingArguments` instance is configured, you can pass it to the `Trainer` class along with your model and datasets. This setup allows the `Trainer` to utilize these arguments throughout the training and evaluation phases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindnlp.engine import Trainer\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    train_dataset=small_dataset_train,\n",
    "    eval_dataset=small_dataset_val,\n",
    "    compute_metrics=compute_metrics,\n",
    "    args=training_args,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Start training\n",
    "Now that we are all set, let's start training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Use the trained model\n",
    "You can now use the trained model to predict on a simple example. We define a text, tokenize it and use it as model input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SequenceClassifierOutput(loss=None, logits=Tensor(shape=[1, 2], dtype=Float32, value=\n",
      "[[-2.99299073e+00,  2.96375227e+00]]), hidden_states=None, attentions=None)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from mindspore import Tensor, ops\n",
    "\n",
    "text = \"What an amusing movie!\"\n",
    "\n",
    "# Tokenize the text\n",
    "inputs = tokenizer(text, padding=True, truncation=True, max_length=256)\n",
    "ts_inputs = {key: Tensor(val).expand_dims(0) for key, val in inputs.items()}\n",
    "\n",
    "# Predict\n",
    "model.set_train(False)\n",
    "outputs = model(**ts_inputs)\n",
    "print(outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The outputs are logits, which can be converted to the probabilities that the given text belongs to each category."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Negative sentiment: 0.0026\n",
      "Positive sentiment: 0.9974\n"
     ]
    }
   ],
   "source": [
    "# Convert predictions to probabilities\n",
    "predictions = ops.softmax(outputs.logits)\n",
    "probabilities = predictions.numpy().flatten()\n",
    "\n",
    "# Here first class is 'negative' and the second is 'positive'\n",
    "print(f\"Negative sentiment: {probabilities[0]:.4f}\")\n",
    "print(f\"Positive sentiment: {probabilities[1]:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4bDlBGGz5tM-"
   },
   "source": [
    "<a id='train_with_native_mindspore'></a>\n",
    "### Train in native MindSpore\n",
    "If you prefer to have more customized control over the training process, you can also fine-tune a model in native MindSpore.\n",
    "\n",
    "If you went through the [Train with MindNLP Trainer](#train_with_mindnlp_trainer) part, you may need to restart your notebook and re-run the [Prepare a dataset](#prepare_a_dataset) part, or execute the following code to free some memory:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "e38pTNG95tM-"
   },
   "outputs": [],
   "source": [
    "# Free up memory by deleting model and trainer used in the Train with MindNLP Trainer step\n",
    "del model\n",
    "del trainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5pOe5fXc5tNA"
   },
   "source": [
    "#### Load the model\n",
    "Load your model with the number of expected labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "qNS3-kES5tNA"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following parameters in checkpoint files are not loaded:\n",
      "['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
      "The following parameters in models are missing parameter:\n",
      "['classifier.weight', 'classifier.bias']\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.transformers import AutoModelForSequenceClassification\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-cased\", num_labels=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_AwyEa8l5tNB"
   },
   "source": [
    "#### Optimizer and loss function\n",
    "Set up the optimizer, which updates the model parameters to minimize the loss function based on the computed gradients. Let's use the `AdamW` optimizer from MindSpore's `mindspore.experimental.optim` module:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "dv5sNMau5tNB"
   },
   "outputs": [],
   "source": [
    "from mindspore.experimental.optim import AdamW\n",
    "\n",
    "optimizer = AdamW(model.trainable_params(), lr=5e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the loss function, which quantifies the difference between the model's predictions and the actual target values. Here we use the cross-entropy loss function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import ops\n",
    "loss_fn = ops.cross_entropy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Forward and Gradient Functions\n",
    "\n",
    "Define a forward function `forward_fn` to manage the forward pass of the model and compute the loss.\n",
    "\n",
    "Then make use of MindSpore's `value_and_grad`, and define a gradient function `grad_fn` to automatically compute both the loss and the gradients of this loss with respect to the model's parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import value_and_grad\n",
    "from tqdm import tqdm\n",
    "\n",
    "def forward_fn(data, labels):\n",
    "    logits = model(**data).logits\n",
    "    loss = loss_fn(logits, labels)\n",
    "    return loss\n",
    "\n",
    "grad_fn = value_and_grad(forward_fn, None, optimizer.parameters)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training step\n",
    "\n",
    "Implement a `train_step` function that will be executed in each step of the training.\n",
    "\n",
    "This function processes a single batch of data, computes the loss and gradients, and updates the model parameters."
   ]
  },
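  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_step(batch):\n",
    "    # Separate the labels from the model inputs\n",
    "    labels = batch.pop('labels')\n",
    "    # Compute the loss and its gradients with respect to the optimizer's parameters\n",
    "    loss, grads = grad_fn(batch, labels)\n",
    "    # Apply the gradients to update the model parameters\n",
    "    optimizer(grads)\n",
    "    return loss"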
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "760XlzSi5tND"
   },
   "source": [
    "#### Training loop for one epoch\n",
    "\n",
    "Implement a `train_one_epoch` function that trains the model for one epoch by iterating over all batches in the dataset:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def train_one_epoch(model, train_dataset, epoch=0):\n",
    "    model.set_train(True)\n",
    "    total = train_dataset.get_dataset_size()\n",
    "    loss_total = 0\n",
    "    step_total = 0\n",
    "    with tqdm(total=total) as progress_bar:\n",
    "        progress_bar.set_description('Epoch %i' % epoch)\n",
    "        for batch in train_dataset.create_dict_iterator():\n",
    "            loss = train_step(batch)\n",
    "            loss_total += loss.asnumpy()\n",
    "            step_total += 1\n",
    "            progress_bar.set_postfix(loss=loss_total/step_total)\n",
    "            progress_bar.update(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before the training loop starts, `train_one_epoch` sets the model to training mode with `model.set_train(True)`.\n",
    "\n",
    "In each iteration, the function calls `train_step` on the current batch of data.\n",
    "\n",
    "To keep track of the training progress, it also accumulates the loss and displays the running average across batches in a progress bar, giving a real-time view of the epoch."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluation\n",
    "\n",
    "Create a function to compute the accuracy of the model's predictions. As in training with the Trainer API, we use the `evaluate` package from Hugging Face."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "import numpy as np\n",
    "\n",
    "metric = evaluate.load(\"accuracy\")\n",
    "\n",
    "def compute_accuracy(logits, labels):\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return metric.compute(predictions=predictions, references=labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Implement a function `evaluate_fn` to evaluate the model on a validation dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_fn(model, test_dataset, criterion, epoch=0):\n",
    "    total = test_dataset.get_dataset_size()\n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    step_total = 0\n",
    "    model.set_train(False)\n",
    "\n",
    "    with tqdm(total=total) as progress_bar:\n",
    "        progress_bar.set_description('Epoch %i' % epoch)\n",
    "        for batch in test_dataset.create_dict_iterator():\n",
    "            label = batch.pop('labels')\n",
    "            logits = model(**batch).logits\n",
    "            loss = criterion(logits, label)\n",
    "            epoch_loss += loss.asnumpy()\n",
    "\n",
    "            acc = compute_accuracy(logits, label)['accuracy']\n",
    "            epoch_acc += acc\n",
    "\n",
    "            step_total += 1\n",
    "            progress_bar.set_postfix(loss=epoch_loss/step_total, acc=epoch_acc/step_total)\n",
    "            progress_bar.update(1)\n",
    "\n",
    "    return epoch_loss / total"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At the start of the evaluation, `evaluate_fn` disables training mode with `model.set_train(False)`.\n",
    "\n",
    "The function then iterates over all batches of the given dataset. For each batch, it computes the logits, calculates the loss, and assesses the accuracy. These metrics are accumulated to provide average loss and accuracy for the epoch, which are displayed on a progress bar."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training loop for all epochs\n",
    "\n",
    "Finally, we can execute the training, which loops through each epoch and, at the end of each epoch, evaluates the model's performance.\n",
    "\n",
    "When the validation loss is lower than in all previous epochs, the model parameters are saved as a checkpoint file for future use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:14<00:00,  2.67it/s, loss=0.591]\n",
      "Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:13<00:00, 14.59it/s, acc=0.744, loss=0.531]\n",
      "Epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:14<00:00,  2.69it/s, loss=0.457]\n",
      "Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:13<00:00, 14.70it/s, acc=0.807, loss=0.498]\n",
      "Epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:14<00:00,  2.70it/s, loss=0.259]\n",
      "Epoch 2: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:13<00:00, 14.70it/s, acc=0.82, loss=0.464]\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "num_epochs = 3\n",
    "best_valid_loss = float('inf')\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    train_one_epoch(model, small_dataset_train, epoch)\n",
    "    valid_loss = evaluate_fn(model, small_dataset_val, loss_fn, epoch)\n",
    "\n",
    "    if valid_loss < best_valid_loss:\n",
    "        best_valid_loss = valid_loss\n",
    "        ms.save_checkpoint(model, '../../sentiment_analysis.ckpt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Use the trained model\n",
    "If you are curious about how your trained model actually classifies text to its sentiment category, try the following code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SequenceClassifierOutput(loss=None, logits=Tensor(shape=[1, 2], dtype=Float32, value=\n",
      "[[-1.56168151e+00,  2.04059339e+00]]), hidden_states=None, attentions=None)\n",
      "Negative sentiment: 0.0265\n",
      "Positive sentiment: 0.9735\n"
     ]
    }
   ],
   "source": [
    "# Predict on example\n",
    "import numpy as np\n",
    "from mindspore import Tensor, ops\n",
    "\n",
    "text = \"I am pretty convinced that the movie depicted the future of AI in an elegant way.\"\n",
    "\n",
    "# Encode the text to input IDs and attention masks\n",
    "inputs = tokenizer(text, padding=True, truncation=True, max_length=256)\n",
    "ts_inputs = {key: Tensor(val).expand_dims(0) for key, val in inputs.items()}\n",
    "\n",
    "# Predict\n",
    "model.set_train(False)\n",
    "outputs = model(**ts_inputs)\n",
    "print(outputs)\n",
    "\n",
    "# Convert predictions to probabilities\n",
    "predictions = ops.softmax(outputs.logits)\n",
    "probabilities = predictions.numpy().flatten()\n",
    "\n",
    "# Here first class is 'negative' and the second is 'positive'\n",
    "print(f\"Negative sentiment: {probabilities[0]:.4f}\")\n",
    "print(f\"Positive sentiment: {probabilities[1]:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
